/*
 * Copyright 2019 Web3 Labs Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
 * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations under the License.
 */
package org.web3j.crypto;

import static org.bouncycastle.util.BigIntegers.TWO;
import static org.web3j.crypto.SignatureDataOperations.CHAIN_ID_INC;
import static org.web3j.crypto.SignatureDataOperations.LOWER_REAL_V;
import static org.web3j.utils.Assertions.verifyPrecondition;

import java.math.BigInteger;
import java.security.KeyFactory;
import java.security.SignatureException;
import java.util.Arrays;

import org.bouncycastle.asn1.x9.X9ECParameters;
import org.bouncycastle.asn1.x9.X9IntegerConverter;
import org.bouncycastle.crypto.digests.SM3Digest;
import org.bouncycastle.crypto.ec.CustomNamedCurves;
import org.bouncycastle.crypto.params.ECDomainParameters;
import org.bouncycastle.jce.interfaces.ECPublicKey;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.jce.spec.ECParameterSpec;
import org.bouncycastle.jce.spec.ECPublicKeySpec;
import org.bouncycastle.math.ec.ECAlgorithms;
import org.bouncycastle.math.ec.ECPoint;
import org.bouncycastle.math.ec.FixedPointCombMultiplier;
import org.bouncycastle.math.ec.custom.gm.SM2P256V1Curve;
import org.web3j.crypto.gm.SM2Signature;
import org.web3j.crypto.gm.SM2Signer;
import org.web3j.utils.Numeric;

/**
 * Transaction signing logic.
 *
 * <p>
 * Adapted from the <a href=
 * "https://github.com/bitcoinj/bitcoinj/blob/master/core/src/main/java/org/bitcoinj/core/ECKey.java">
 * BitcoinJ ECKey</a> implementation.
 */
public class Sign {

//    public static final X9ECParameters CURVE_PARAMS = CustomNamedCurves.getByName("secp256k1");
	public static final X9ECParameters CURVE_PARAMS = CustomNamedCurves.getByName("sm2p256v1");

	static final ECDomainParameters CURVE = new ECDomainParameters(CURVE_PARAMS.getCurve(), CURVE_PARAMS.getG(),
			CURVE_PARAMS.getN(), CURVE_PARAMS.getH());

	static final BigInteger HALF_CURVE_ORDER = CURVE_PARAMS.getN().shiftRight(1);

	static final String MESSAGE_PREFIX = "\u0019Ethereum Signed Message:\n";

	static byte[] getEthereumMessagePrefix(int messageLength) {
		return MESSAGE_PREFIX.concat(String.valueOf(messageLength)).getBytes();
	}

	public static byte[] getEthereumMessageHash(byte[] message) {
		byte[] prefix = getEthereumMessagePrefix(message.length);

		byte[] result = new byte[prefix.length + message.length];
		System.arraycopy(prefix, 0, result, 0, prefix.length);
		System.arraycopy(message, 0, result, prefix.length, message.length);

//		return Hash.sha3(result);
		return Hash.sm3(result);
	}

	public static SignatureData signPrefixedMessage(byte[] message, ECKeyPair keyPair) {
		return signMessage(getEthereumMessageHash(message), keyPair, false);
	}

	public static SignatureData signMessage(byte[] message, ECKeyPair keyPair) {
		return signMessage(message, keyPair, true);
	}

//	public static SignatureData signMessage(byte[] message, ECKeyPair keyPair, boolean needToHash) {
//		BigInteger publicKey = keyPair.getPublicKey();
//		byte[] messageHash;
//		if (needToHash) {
//			messageHash = Hash.sha3(message);
//		} else {
//			messageHash = message;
//		}
//		
//		ECDSASignature sig = keyPair.ecdaSign(messageHash);
//		// Now we have to work backwards to figure out the recId needed to recover the
//		// signature.
//		int recId = -1;
//		for (int i = 0; i < 4; i++) {
//			BigInteger k = recoverFromSignature(i, sig, messageHash);
//			if (k != null && k.equals(publicKey)) {
//				recId = i;
//				break;
//			}
//		}
//		if (recId == -1) {
//			throw new RuntimeException("Could not construct a recoverable key. Are your credentials valid?");
//		}
//
//		int headerByte = recId + 27;
//
//		// 1 header + 32 bytes for R + 32 bytes for S
//		byte[] v = new byte[] { (byte) headerByte };
//		byte[] r = Numeric.toBytesPadded(sig.r, 32);
//		byte[] s = Numeric.toBytesPadded(sig.s, 32);
//
//		return new SignatureData(v, r, s);
//	}

	public static SignatureData signMessage(byte[] message, ECKeyPair keyPair, boolean needToHash) {
		byte[] messageHash = null;
		if (needToHash) {
			messageHash = Hash.sm3(message);
		} else {
			messageHash = message;
		}

		SM2Signature signature = keyPair.sm2Signature(messageHash);

		byte[] v = Numeric.toBytesPadded(signature.v, 32);
		byte[] r = Numeric.toBytesPadded(signature.r, 32);
		byte[] s = Numeric.toBytesPadded(signature.s, 32);

		return new SignatureData(v, r, s);
	}

	/**
	 * Given the components of a signature and a selector value, recover and return
	 * the public key that generated the signature according to the algorithm in
	 * SEC1v2 section 4.1.6.
	 *
	 * <p>
	 * The recId is an index from 0 to 3 which indicates which of the 4 possible
	 * keys is the correct one. Because the key recovery operation yields multiple
	 * potential keys, the correct key must either be stored alongside the
	 * signature, or you must be willing to try each recId in turn until you find
	 * one that outputs the key you are expecting.
	 *
	 * <p>
	 * If this method returns null it means recovery was not possible and recId
	 * should be iterated.
	 *
	 * <p>
	 * Given the above two points, a correct usage of this method is inside a for
	 * loop from 0 to 3, and if the output is null OR a key that is not the one you
	 * expect, you try again with the next recId.
	 *
//	 * @param recId   Which possible key to recover.
//	 * @param sig     the R and S components of the signature, wrapped.
//	 * @param message Hash of the data that was signed.
//	 * @return An ECKey containing only the public part, or null if recovery wasn't
	 *         possible.
	 */
//	public static BigInteger recoverFromSignature(int recId, ECDSASignature sig, byte[] message) {
//		verifyPrecondition(recId >= 0, "recId must be positive");
//		verifyPrecondition(sig.r.signum() >= 0, "r must be positive");
//		verifyPrecondition(sig.s.signum() >= 0, "s must be positive");
//		verifyPrecondition(message != null, "message cannot be null");
//
//		// 1.0 For j from 0 to h (h == recId here and the loop is outside this function)
//		// 1.1 Let x = r + jn
//		BigInteger n = CURVE.getN(); // Curve order.
//		BigInteger i = BigInteger.valueOf((long) recId / 2);
//		BigInteger x = sig.r.add(i.multiply(n));
//		// 1.2. Convert the integer x to an octet string X of length mlen using the
//		// conversion
//		// routine specified in Section 2.3.7, where mlen = ⌈(log2 p)/8⌉ or mlen =
//		// ⌈m/8⌉.
//		// 1.3. Convert the octet string (16 set binary digits)||X to an elliptic curve
//		// point R
//		// using the conversion routine specified in Section 2.3.4. If this conversion
//		// routine outputs "invalid", then do another iteration of Step 1.
//		//
//		// More concisely, what these points mean is to use X as a compressed public
//		// key.
//		BigInteger prime = SecP256K1Curve.q;
//		if (x.compareTo(prime) >= 0) {
//			// Cannot have point co-ordinates larger than this as everything takes place
//			// modulo Q.
//			return null;
//		}
//		// Compressed keys require you to know an extra bit of data about the y-coord as
//		// there are
//		// two possibilities. So it's encoded in the recId.
//		ECPoint R = decompressKey(x, (recId & 1) == 1);
//		// 1.4. If nR != point at infinity, then do another iteration of Step 1 (callers
//		// responsibility).
//		if (!R.multiply(n).isInfinity()) {
//			return null;
//		}
//		// 1.5. Compute e from M using Steps 2 and 3 of ECDSA signature verification.
//		BigInteger e = new BigInteger(1, message);
//		// 1.6. For k from 1 to 2 do the following. (loop is outside this function via
//		// iterating recId)
//		// 1.6.1. Compute a candidate public key as:
//		// Q = mi(r) * (sR - eG)
//		//
//		// Where mi(x) is the modular multiplicative inverse. We transform this into the
//		// following:
//		// Q = (mi(r) * s ** R) + (mi(r) * -e ** G)
//		// Where -e is the modular additive inverse of e, that is z such that z + e = 0
//		// (mod n).
//		// In the above equation ** is point multiplication and + is point addition (the
//		// EC group
//		// operator).
//		//
//		// We can find the additive inverse by subtracting e from zero then taking the
//		// mod. For
//		// example the additive inverse of 3 modulo 11 is 8 because 3 + 8 mod 11 = 0,
//		// and
//		// -3 mod 11 = 8.
//		BigInteger eInv = BigInteger.ZERO.subtract(e).mod(n);
//		BigInteger rInv = sig.r.modInverse(n);
//		BigInteger srInv = rInv.multiply(sig.s).mod(n);
//		BigInteger eInvrInv = rInv.multiply(eInv).mod(n);
//		ECPoint q = ECAlgorithms.sumOfTwoMultiplies(CURVE.getG(), eInvrInv, R, srInv);
//
//		byte[] qBytes = q.getEncoded(false);
//		// We remove the prefix
//		return new BigInteger(1, Arrays.copyOfRange(qBytes, 1, qBytes.length));
//	}

	public static ECPoint recoverECPoint(SM2Signature signature, int iter) {
		// 1.0 For j from 0 to h (h == recId here and the loop is outside this function)
		// 1.1 Let x = r + jn
		// 1.1 x = (n * i) + r - e
		BigInteger n = CURVE.getN(); // Curve order.
		BigInteger i = BigInteger.valueOf((long) iter / 2);
		BigInteger x = signature.r.add(i.multiply(n)).subtract(signature.v);

		BigInteger prime = SM2P256V1Curve.q;
		if (x.compareTo(prime) >= 0) {
			// Cannot have point co-ordinates larger than this as everything takes place
			// modulo Q.
			return null;
		}
		// Compressed keys require you to know an extra bit of data about the y-coord as
		// there are
		// two possibilities. So it's encoded in the recId.
		ECPoint R = decompressKey(x, (iter & 1) == 1);
		// 1.4. If nR != point at infinity, then do another iteration of Step 1 (callers
		// responsibility).
		if (R == null || !R.multiply(n).isInfinity()) {
			return null;
		}
		return R;
	}

	public static BigInteger recoverFromSignature(SM2Signature sig, byte[] message) {
		verifyPrecondition(sig.v.signum() >= 0, "v must be positive");
		verifyPrecondition(sig.r.signum() >= 0, "r must be positive");
		verifyPrecondition(sig.s.signum() >= 0, "s must be positive");
		verifyPrecondition(message != null, "message cannot be null");
		
		BigInteger n = CURVE.getN(); // Curve order.
		ECPoint r = null;
		for (int i = 0; i < 4; i++) {
			r= recoverECPoint(sig, i);

			if (r == null) {
				continue;
			}
			// invr := new(big.Int).ModInverse(new(big.Int).Add(sig.S, sig.R), curve.Params().N)
			BigInteger invR = sig.r.add(sig.s).modInverse(n);
//			ECPoint R = combMultiplier.multiply(R0, invR);
//			s := new(big.Int).Set(sig.S)
//					s.Neg(s)
//					s.Mod(s, curve.Params().N)
//					s.Mul(s, invr)
//					s.Mod(s, curve.Params().N)
//					minuseGx, minuseGy := curve.ScalarBaseMult(s.Bytes())
			BigInteger sInvR = BigInteger.ZERO.subtract(sig.s).mod(n).multiply(invR).mod(n);
//			ECPoint g = combMultiplier.multiply(CURVE.getG(), ss);
			//Qx, Qy := curve.Add(sRx, sRy, minuseGx, minuseGy)
			ECPoint q = ECAlgorithms.sumOfTwoMultiplies(CURVE.getG(), sInvR, r, invR).normalize();
			
			try {
				// 计算E
				BigInteger e = SM2Signer.calculateE(new SM3Digest(), CURVE, q, null, message);
				if (e.compareTo(sig.v) == 0) {
					byte[] qBytes = q.getEncoded(false);
					// We remove the prefix
					BigInteger k = new BigInteger(1, Arrays.copyOfRange(qBytes, 1, qBytes.length));
					return k;
				}
			} catch (Exception e) {
				e.printStackTrace();
			}

		}

		return null;

	}

	/** Decompress a compressed public key (x co-ord and low-bit of y-coord). */
	private static ECPoint decompressKey(BigInteger xBN, boolean yBit) {
		try {
			X9IntegerConverter x9 = new X9IntegerConverter();
			byte[] compEnc = x9.integerToBytes(xBN, 1 + x9.getByteLength(CURVE.getCurve()));
			compEnc[0] = (byte) (yBit ? 0x03 : 0x02);
			return CURVE.getCurve().decodePoint(compEnc);
		} catch (Exception e) {
			return null;
		}
	}

	public static ECPublicKey createPublicKey(ECDomainParameters ecDomainParameters, BigInteger x, BigInteger y) throws Exception {
		KeyFactory keyFactory = KeyFactory.getInstance("EC", BouncyCastleProvider.PROVIDER_NAME);
		ECParameterSpec ecps = new ECParameterSpec(ecDomainParameters.getCurve(), //
				ecDomainParameters.getG(), ecDomainParameters.getN());
		ECPoint q = ecDomainParameters.getCurve().createPoint(x, y); // 创建公钥点
		ECPublicKeySpec keySpec = new ECPublicKeySpec(q, ecps);
		ECPublicKey publicKey = (ECPublicKey) keyFactory.generatePublic(keySpec);
		
		
		return publicKey;
	}

	/**
	 * Given an arbitrary piece of text and an Ethereum message signature encoded in
	 * bytes, returns the public key that was used to sign it. This can then be
	 * compared to the expected public key to determine if the signature was
	 * correct.
	 *
	 * @param message       RLP encoded message.
	 * @param signatureData The message signature components
	 * @return the public key used to sign the message
	 * @throws SignatureException If the public key could not be recovered or if
	 *                            there was a signature format error.
	 */
//	public static BigInteger signedMessageToKey(byte[] message, SignatureData signatureData) throws SignatureException {
//		return signedMessageHashToKey(Hash.sha3(message), signatureData);
//	}

	public static BigInteger signedMessageToKey(byte[] message, SignatureData signatureData) throws SignatureException {
		return signedMessageHashToKey(Hash.sm3(message), signatureData);
	}

	/**
	 * Given an arbitrary message and an Ethereum message signature encoded in
	 * bytes, returns the public key that was used to sign it. This can then be
	 * compared to the expected public key to determine if the signature was
	 * correct.
	 *
	 * @param message       The message.
	 * @param signatureData The message signature components
	 * @return the public key used to sign the message
	 * @throws SignatureException If the public key could not be recovered or if
	 *                            there was a signature format error.
	 */
	public static BigInteger signedPrefixedMessageToKey(byte[] message, SignatureData signatureData)
			throws SignatureException {
		return signedMessageHashToKey(getEthereumMessageHash(message), signatureData);
	}

	/**
	 * Given an arbitrary message hash and an Ethereum message signature encoded in
	 * bytes, returns the public key that was used to sign it. This can then be
	 * compared to the expected public key to determine if the signature was
	 * correct.
	 *
	 * @param messageHash   The message hash.
	 * @param signatureData The message signature components
	 * @return the public key used to sign the message
	 * @throws SignatureException If the public key could not be recovered or if
	 *                            there was a signature format error.
	 */
//	public static BigInteger signedMessageHashToKey(byte[] messageHash, SignatureData signatureData)
//			throws SignatureException {
//
//		byte[] r = signatureData.getR();
//		byte[] s = signatureData.getS();
//		verifyPrecondition(r != null && r.length == 32, "r must be 32 bytes");
//		verifyPrecondition(s != null && s.length == 32, "s must be 32 bytes");
//
//		int header = signatureData.getV()[0] & 0xFF;
//		// The header byte: 0x1B = first key with even y, 0x1C = first key with odd y,
//		// 0x1D = second key with even y, 0x1E = second key with odd y
//		if (header < 27 || header > 34) {
//			throw new SignatureException("Header byte out of range: " + header);
//		}
//
//		ECDSASignature sig = new ECDSASignature(new BigInteger(1, signatureData.getR()),
//				new BigInteger(1, signatureData.getS()));
//
//		int recId = header - 27;
//		BigInteger key = recoverFromSignature(recId, sig, messageHash);
//		if (key == null) {
//			throw new SignatureException("Could not recover public key from signature");
//		}
//		return key;
//	}

	public static BigInteger signedMessageHashToKey(byte[] messageHash, SignatureData signatureData)
			throws SignatureException {

		byte[] r = signatureData.getR();
		byte[] s = signatureData.getS();
		byte[] v = signatureData.getV();

		verifyPrecondition(r != null && r.length == 32, "r must be 32 bytes");
		verifyPrecondition(s != null && s.length == 32, "s must be 32 bytes");
		verifyPrecondition(v != null && v.length == 32, "v must be 32 bytes");

		SM2Signature sig = new SM2Signature(new BigInteger(1, r), new BigInteger(1, s), new BigInteger(1, v));

		BigInteger key = recoverFromSignature(sig, messageHash);
		if (key == null) {
			throw new SignatureException("Could not recover public key from signature");
		}
		return key;
	}

	/**
	 * Returns recovery ID.
	 *
	 * @param signatureData The message signature components
	 * @param chainId       of the network
	 * @return int recovery ID
	 */
	@Deprecated
	public static int getRecId(SignatureData signatureData, long chainId) {
		BigInteger v = Numeric.toBigInt(signatureData.getV());
		BigInteger lowerRealV = BigInteger.valueOf(LOWER_REAL_V);
		BigInteger lowerRealVPlus1 = BigInteger.valueOf(LOWER_REAL_V + 1);
		BigInteger chainIdInc = BigInteger.valueOf(CHAIN_ID_INC);
		if (v.equals(lowerRealV) || v.equals(lowerRealVPlus1)) {
			return v.subtract(lowerRealV).intValue();
		} else if (v.compareTo(chainIdInc) > 0) {
			return v.subtract(BigInteger.valueOf(chainId).multiply(TWO)).add(chainIdInc).intValue();
		} else {
			throw new RuntimeException(String.format("Unsupported format exception", v));
		}
	}

	/**
	 * Returns the header 'v'.
	 *
	 * @param recId The recovery id.
	 * @return byte[] header 'v'.
	 */
	@Deprecated
	public static byte[] getVFromRecId(int recId) {
		return new byte[] { (byte) (LOWER_REAL_V + recId) };
	}

	/**
	 * Returns public key from the given private key.
	 *
	 * @param privKey the private key to derive the public key from
	 * @return BigInteger encoded public key
	 */
	public static BigInteger publicKeyFromPrivate(BigInteger privKey) {
		ECPoint point = publicPointFromPrivate(privKey);

		byte[] encoded = point.getEncoded(false);
		return new BigInteger(1, Arrays.copyOfRange(encoded, 1, encoded.length)); // remove prefix
	}

	/**
	 * Returns public key point from the given private key.
	 *
	 * @param privKey the private key to derive the public key from
	 * @return ECPoint public key
	 */
	public static ECPoint publicPointFromPrivate(BigInteger privKey) {
		/*
		 * TODO: FixedPointCombMultiplier currently doesn't support scalars longer than
		 * the group order, but that could change in future versions.
		 */
		if (privKey.bitLength() > CURVE.getN().bitLength()) {
			privKey = privKey.mod(CURVE.getN());
		}
		return new FixedPointCombMultiplier().multiply(CURVE.getG(), privKey);
	}

	public static byte[] getPublicKeyBytes(BigInteger pubKey) {
		byte[] pubData = pubKey.toByteArray();
		if (pubData.length == 65) {
			pubData[0] = 4;
			return pubData;
		}
		byte[] publicKeyBytes = new byte[pubData.length + 1];
		System.arraycopy(pubData, 0, publicKeyBytes, 1, pubData.length);
		publicKeyBytes[0] = 4;

		return publicKeyBytes;
	}

	public static ECPoint publicPointFromPublic(BigInteger publicKey) {
		return Sign.CURVE.getCurve().decodePoint(getPublicKeyBytes(publicKey));
	}

	/**
	 * Returns public key point from the given curve.
	 *
	 * @param bits representing the point on the curve
	 * @return BigInteger encoded public key
	 */
	public static BigInteger publicFromPoint(byte[] bits) {
		return new BigInteger(1, Arrays.copyOfRange(bits, 1, bits.length)); // remove prefix
	}

	public static class SignatureData {
		private final byte[] v;
		private final byte[] r;
		private final byte[] s;

		public SignatureData(byte v, byte[] r, byte[] s) {
			this(new byte[] { v }, r, s);
		}

		public SignatureData(byte[] v, byte[] r, byte[] s) {
			this.v = v;
			this.r = r;
			this.s = s;
		}

		public byte[] getV() {
			return v;
		}

		public byte[] getR() {
			return r;
		}

		public byte[] getS() {
			return s;
		}

		@Override
		public boolean equals(Object o) {
			if (this == o) {
				return true;
			}
			if (o == null || getClass() != o.getClass()) {
				return false;
			}

			SignatureData that = (SignatureData) o;

			if (!Arrays.equals(v, that.v)) {
				return false;
			}
			if (!Arrays.equals(r, that.r)) {
				return false;
			}
			return Arrays.equals(s, that.s);
		}

		@Override
		public int hashCode() {
			int result = Arrays.hashCode(v);
			result = 31 * result + Arrays.hashCode(r);
			result = 31 * result + Arrays.hashCode(s);
			return result;
		}
	}
}
